Conversation

@noemotiovon (Contributor)

  • Refactor the Ascend GEGLU kernels to use a flattened 1D grid-stride loop pattern instead of the row-based tiling approach, for better performance (see the sketch below)
  • Simplify the block size calculation using compute_default_tiling_strategy
  • Align the type conversion logic with the GPU version for consistency
  • Update test tolerances for NPU bfloat16 (atol up to 1e4) to handle precision differences
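
For reference, a minimal sketch of the flattened 1D grid-stride pattern for the GEGLU forward (gelu_tanh(a) * b). The kernel name, signature, and constants below are illustrative only, not the exact code in this PR:

import triton
import triton.language as tl


@triton.jit
def geglu_fwd_flat_kernel(a_ptr, b_ptr, c_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Flattened 1D view: treat the (n_rows, n_cols) problem as n_elements contiguous
    # values, and let each program stride over it in BLOCK_SIZE chunks.
    pid = tl.program_id(0)
    step = tl.num_programs(0) * BLOCK_SIZE
    for start in range(pid * BLOCK_SIZE, n_elements, step):
        offs = start + tl.arange(0, BLOCK_SIZE)
        mask = offs < n_elements
        a = tl.load(a_ptr + offs, mask=mask, other=0.0).to(tl.float32)  # upcast, as in the GPU kernel
        b = tl.load(b_ptr + offs, mask=mask, other=0.0)
        # tanh-approximate GELU (gelu_pytorch_tanh), computed in fp32
        z = 0.7978845608028654 * (a + 0.044715 * a * a * a)
        tanh_z = 1.0 - 2.0 / (tl.exp(2.0 * z) + 1.0)
        gelu_a = 0.5 * a * (1.0 + tanh_z)
        tl.store(c_ptr + offs, gelu_a.to(b.dtype) * b, mask=mask)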

Hardware Type: Ascend 910B4

  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@noemotiovon (Contributor, Author)

Benchmark:

**************************************
     BENCHMARKING SPEED for GEGLU
**************************************
[WARNING] Please DO NOT tune args ['num_warps']!
[WARNING] Please DO NOT tune args ['num_warps']!
********** Benchmark Data **********
[
  {
    "kernel_name": "geglu",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      41.29764175415039,
      77.97412109375,
      153.22802734375,
      305.4518127441406
    ],
    "y_values_20": [
      41.29764175415039,
      77.97412109375,
      153.22802734375,
      305.4518127441406
    ],
    "y_values_80": [
      41.29764175415039,
      77.97412109375,
      153.22802734375,
      305.4518127441406
    ],
    "timestamp": "2026-01-20 02:12:01",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"bsz\": 8, \"hidden_size\": 4096, \"intermediate_size\": 11008, \"hidden_act\": \"gelu_pytorch_tanh\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "geglu",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      43.27560043334961,
      82.03520202636719,
      159.79339599609375,
      318.13525390625
    ],
    "y_values_20": [
      43.27560043334961,
      82.03520202636719,
      159.79339599609375,
      318.13525390625
    ],
    "y_values_80": [
      43.27560043334961,
      82.03520202636719,
      159.79339599609375,
      318.13525390625
    ],
    "timestamp": "2026-01-20 02:12:10",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"bsz\": 8, \"hidden_size\": 4096, \"intermediate_size\": 11008, \"hidden_act\": \"gelu_pytorch_tanh\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "geglu",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      10.68362045288086,
      20.787700653076172,
      41.33747863769531,
      82.21530151367188
    ],
    "y_values_20": [
      10.68362045288086,
      20.787700653076172,
      41.33747863769531,
      82.21530151367188
    ],
    "y_values_80": [
      10.68362045288086,
      20.787700653076172,
      41.33747863769531,
      82.21530151367188
    ],
    "timestamp": "2026-01-20 02:12:15",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"bsz\": 8, \"hidden_size\": 4096, \"intermediate_size\": 11008, \"hidden_act\": \"gelu_pytorch_tanh\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "geglu",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      11.14247989654541,
      21.72364044189453,
      43.50083923339844,
      86.65560150146484
    ],
    "y_values_20": [
      11.14247989654541,
      21.72364044189453,
      43.50083923339844,
      86.65560150146484
    ],
    "y_values_80": [
      11.14247989654541,
      21.72364044189453,
      43.50083923339844,
      86.65560150146484
    ],
    "timestamp": "2026-01-20 02:12:20",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"bsz\": 8, \"hidden_size\": 4096, \"intermediate_size\": 11008, \"hidden_act\": \"gelu_pytorch_tanh\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "geglu",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      27.273399353027344,
      50.27893829345703,
      98.00166320800781,
      193.2967987060547
    ],
    "y_values_20": [
      27.273399353027344,
      50.27893829345703,
      98.00166320800781,
      193.2967987060547
    ],
    "y_values_80": [
      27.273399353027344,
      50.27893829345703,
      98.00166320800781,
      193.2967987060547
    ],
    "timestamp": "2026-01-20 02:12:27",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"bsz\": 8, \"hidden_size\": 4096, \"intermediate_size\": 11008, \"hidden_act\": \"gelu_pytorch_tanh\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "geglu",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      28.78070068359375,
      53.18796157836914,
      102.28961944580078,
      204.07730102539062
    ],
    "y_values_20": [
      28.78070068359375,
      53.18796157836914,
      102.28961944580078,
      204.07730102539062
    ],
    "y_values_80": [
      28.78070068359375,
      53.18796157836914,
      102.28961944580078,
      204.07730102539062
    ],
    "timestamp": "2026-01-20 02:12:35",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"bsz\": 8, \"hidden_size\": 4096, \"intermediate_size\": 11008, \"hidden_act\": \"gelu_pytorch_tanh\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  }
]
**************************************
     BENCHMARKING MEMORY for GEGLU
**************************************
********** Benchmark Data **********
[
  {
    "kernel_name": "geglu",
    "kernel_provider": "liger",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      1910.0087890625,
      3092.0078125,
      5668.0078125,
      10820.0078125
    ],
    "y_values_20": [
      1910.0087890625,
      3092.0078125,
      5668.0078125,
      10820.0078125
    ],
    "y_values_80": [
      1910.0087890625,
      3092.0078125,
      5668.0078125,
      10820.0078125
    ],
    "timestamp": "2026-01-20 02:12:40",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"bsz\": 8, \"hidden_size\": 4096, \"intermediate_size\": 11008, \"hidden_act\": \"gelu_pytorch_tanh\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "geglu",
    "kernel_provider": "huggingface",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      2082.00927734375,
      3436.00830078125,
      6356.00830078125,
      12196.0078125
    ],
    "y_values_20": [
      2082.00927734375,
      3436.00830078125,
      6356.00830078125,
      12196.0078125
    ],
    "y_values_80": [
      2082.00927734375,
      3436.00830078125,
      6356.00830078125,
      12196.0078125
    ],
    "timestamp": "2026-01-20 02:12:49",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"bsz\": 8, \"hidden_size\": 4096, \"intermediate_size\": 11008, \"hidden_act\": \"gelu_pytorch_tanh\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "geglu",
    "kernel_provider": "liger",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      922.0048828125,
      1566.0048828125,
      2854.0048828125,
      5430.0048828125
    ],
    "y_values_20": [
      922.0048828125,
      1566.0048828125,
      2854.0048828125,
      5430.0048828125
    ],
    "y_values_80": [
      922.0048828125,
      1566.0048828125,
      2854.0048828125,
      5430.0048828125
    ],
    "timestamp": "2026-01-20 02:12:56",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"bsz\": 8, \"hidden_size\": 4096, \"intermediate_size\": 11008, \"hidden_act\": \"gelu_pytorch_tanh\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "geglu",
    "kernel_provider": "huggingface",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      1094.00537109375,
      1910.00537109375,
      3542.00537109375,
      6806.00537109375
    ],
    "y_values_20": [
      1094.00537109375,
      1910.00537109375,
      3542.00537109375,
      6806.00537109375
    ],
    "y_values_80": [
      1094.00537109375,
      1910.00537109375,
      3542.00537109375,
      6806.00537109375
    ],
    "timestamp": "2026-01-20 02:13:00",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"bsz\": 8, \"hidden_size\": 4096, \"intermediate_size\": 11008, \"hidden_act\": \"gelu_pytorch_tanh\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "geglu",
    "kernel_provider": "liger",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      1910.0087890625,
      3092.0078125,
      5668.0078125,
      10820.0078125
    ],
    "y_values_20": [
      1910.0087890625,
      3092.0078125,
      5668.0078125,
      10820.0078125
    ],
    "y_values_80": [
      1910.0087890625,
      3092.0078125,
      5668.0078125,
      10820.0078125
    ],
    "timestamp": "2026-01-20 02:13:04",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"bsz\": 8, \"hidden_size\": 4096, \"intermediate_size\": 11008, \"hidden_act\": \"gelu_pytorch_tanh\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "geglu",
    "kernel_provider": "huggingface",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      2082.00927734375,
      3436.00830078125,
      6356.00830078125,
      12196.0078125
    ],
    "y_values_20": [
      2082.00927734375,
      3436.00830078125,
      6356.00830078125,
      12196.0078125
    ],
    "y_values_80": [
      2082.00927734375,
      3436.00830078125,
      6356.00830078125,
      12196.0078125
    ],
    "timestamp": "2026-01-20 02:13:11",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"bsz\": 8, \"hidden_size\": 4096, \"intermediate_size\": 11008, \"hidden_act\": \"gelu_pytorch_tanh\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  }
]

Comment on lines +39 to +40
# TODO: we should find a better way to tune this. 1e4 is too large apparently
1e-2 if device != "npu" else 1e4,
Collaborator

Do you know which tensors couldn't pass with this tolerance? Gradients or inputs?

Contributor Author

Thanks for the question. I double-checked which tensors require the large tolerance.

On NPU with bfloat16:

  • Forward outputs (y1 vs y2) differ on the order of 1e2.
  • Weight gradients (gate_proj / up_proj / down_proj) also differ on the order of 1e2.
  • The largest discrepancy is in the input gradients: x1.grad vs x2.grad can reach the order of 1e4.

So the forward outputs and weight gradients already differ at ~1e2, and the input gradients amplify this difference further, as sketched below.
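
Roughly, writing c = gelu(a) · b for the two projected halves, the backward is just the standard chain rule under the LLaMA-style MLP used by the test (notation below is mine, not code from this PR):

$$
\nabla_c = \nabla_{\mathrm{out}} W_{\mathrm{down}},\qquad
\nabla_a = \nabla_c \, b \, \mathrm{gelu}'(a),\qquad
\nabla_b = \nabla_c \, \mathrm{gelu}(a),\qquad
\nabla_x = \nabla_a W_a + \nabla_b W_b
$$

where W_a and W_b are the corresponding gate/up projection weights. The input gradient is therefore assembled from two more large bf16 GEMMs applied to terms that already carry the forward error, which is why input.grad shows the largest absolute differences.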

Contributor Author

================================================================================
SUMMARY - Minimum atol needed for each tensor (rtol=1e-2):
================================================================================
output                        : min_atol=1e2   , max_abs_diff=2.048000e+03
gate_proj.weight.grad         : min_atol=1e3   , max_abs_diff=2.048000e+03
up_proj.weight.grad           : min_atol=1e2   , max_abs_diff=2.048000e+03
down_proj.weight.grad         : min_atol=1e2   , max_abs_diff=2.048000e+03
input.grad                    : min_atol=1e4   , max_abs_diff=4.096000e+03
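
(For context: min_atol here means the smallest atol at which torch.testing.assert_close still passes for the given rtol. One way to compute such a bound directly; this helper is an illustrative sketch, not the exact script used:)

import torch

def min_atol_needed(actual: torch.Tensor, expected: torch.Tensor, rtol: float) -> float:
    # assert_close passes iff |actual - expected| <= atol + rtol * |expected| elementwise,
    # so the smallest workable atol is the largest violation of the rtol-only bound.
    diff = (actual.float() - expected.float()).abs()
    slack = diff - rtol * expected.float().abs()
    return slack.clamp_min(0).max().item()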

Contributor Author

Also worth noting: the tolerances used here are consistent with the previous NPU GEGLU kernel implementation, so this change does not introduce new numerical error compared to the existing behavior on NPU.

@Tcc0403 (Collaborator), Jan 21, 2026

The problem might be related to the NPU implementation of GEGLU. I guess it is performed in bf16, not fp32?

In Liger, the up projection is implicitly converted to fp32 before the gelu activation function, and then converted back to bf16 before being multiplied by the gate projection.

Run:

import torch
a = torch.randn(1024, dtype=torch.bfloat16).cuda()  # replace .cuda() with equivalent method on npu
gelu = torch.nn.GELU(approximate="tanh")
# pure bf16
b = gelu(a)
# upcast to fp32 and downcast back
c = gelu(a.float()).to(torch.bfloat16)

# Test identical result
torch.testing.assert_close(b, c, rtol=0, atol=0)  # This is true on cuda

If the above script doesn't raise any error, showing that gelu is also performed in fp32 on the NPU, then it's just numerical error from the GEMMs, amplified by the large intermediate size, and I don't mind keeping the large atol in that case.

Otherwise, we probably want to modify the gelu part to match torch's result.

Contributor Author

I tested the standalone GELU behavior on NPU with bf16 input:

import torch
import torch_npu  # Ascend PyTorch adapter; provides the .npu() device method

a = torch.randn(1024, dtype=torch.bfloat16).npu()
gelu = torch.nn.GELU(approximate="tanh")
b = gelu(a)                              # pure bf16
c = gelu(a.float()).to(torch.bfloat16)   # upcast to fp32, downcast back
torch.testing.assert_close(b, c, rtol=0, atol=0)

This test passes without error, which indicates that GELU on NPU is not computed in pure bf16, but matches the fp32 → bf16 behavior of PyTorch CUDA.

Therefore, the mismatch we observe in GEGLU is unlikely to be caused by the GELU implementation itself. It is more likely due to numerical differences in the bf16 GEMM, where errors can be amplified by large intermediate activations.
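
To get a feel for how much the bf16 GEMM alone contributes, independent of GEGLU, a rough probe along these lines can be run on the NPU (sizes taken from the benchmark config; illustrative only):

import torch
# import torch_npu  # on Ascend, then move the tensors to the device with .npu()

torch.manual_seed(0)
k, n = 11008, 4096                      # intermediate_size, hidden_size
x = torch.randn(256, k, dtype=torch.bfloat16)
w = torch.randn(k, n, dtype=torch.bfloat16)
ref = x.float() @ w.float()             # fp32 reference
out = (x @ w).float()                   # bf16 GEMM
print((out - ref).abs().max())          # nonzero, and grows with the reduction dim k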

Collaborator

================================================================================
SUMMARY - Minimum atol needed for each tensor (rtol=1e-2):
================================================================================
output                        : min_atol=1e2   , max_abs_diff=2.048000e+03
gate_proj.weight.grad         : min_atol=1e3   , max_abs_diff=2.048000e+03
up_proj.weight.grad           : min_atol=1e2   , max_abs_diff=2.048000e+03
down_proj.weight.grad         : min_atol=1e2   , max_abs_diff=2.048000e+03
input.grad                    : min_atol=1e4   , max_abs_diff=4.096000e+03

Try rtol=1e-1 and rerun this analysis.

@noemotiovon (Contributor, Author), Jan 21, 2026

With pleasure. Below is a summary of the minimum atol required for each tensor when rtol=0.1:

================================================================================
SUMMARY — Minimum atol required per tensor (rtol=0.1)
================================================================================
output                        : min_atol=1e1   , max_abs_diff=2.048000e+03
gate_proj.weight.grad         : min_atol=4e2   , max_abs_diff=2.048000e+03
up_proj.weight.grad           : min_atol=8e0   , max_abs_diff=2.048000e+03
down_proj.weight.grad         : min_atol=1e1   , max_abs_diff=2.048000e+03
input.grad                    : min_atol=2e3   , max_abs_diff=4.096000e+03
================================================================================
